{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wrangle WikiQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "# Show up to 1000 characters per column value when frames are displayed.\n",
    "pd.options.display.max_colwidth = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the wiki_qa data; `dataset` is indexed by split name in the next cell.\n",
    "dataset = datasets.load_dataset(path=\"wiki_qa\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split to wrangle; the same name is reused for the output file.\n",
    "split = \"train\"\n",
    "split_data = dataset[split]\n",
    "# Work on a pandas copy of the chosen split.\n",
    "raw_df = split_data.to_pandas()\n",
    "raw_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one record per question: join its candidate sentences into a single\n",
    "# document, plus a copy in which the sentences labelled 1 are upper-cased.\n",
    "rows = []\n",
    "# Group by the bare column name rather than a one-element list: with a list\n",
    "# key, pandas >= 2.0 yields 1-tuples such as (\"Q1\",), which would end up in\n",
    "# the JSONL output as [\"Q1\"] instead of \"Q1\".\n",
    "for question_id, group in raw_df.groupby(\"question_id\"):\n",
    "    # Every row of a group must repeat the same question and document title.\n",
    "    questions = group[\"question\"].unique().tolist()\n",
    "    assert len(questions) == 1, f\"{question_id}: expected one question, got {questions}\"\n",
    "    document_titles = group[\"document_title\"].unique().tolist()\n",
    "    assert len(document_titles) == 1, f\"{question_id}: expected one title, got {document_titles}\"\n",
    "    labels = group[\"label\"].to_list()\n",
    "    sentences = group[\"answer\"].to_list()\n",
    "    rows.append(\n",
    "        {\n",
    "            \"query_id\": question_id,\n",
    "            \"query_text\": questions[0],\n",
    "            \"document_title\": document_titles[0],\n",
    "            \"document_text\": \" \".join(sentences),\n",
    "            # Same sentences, with those labelled 1 upper-cased.\n",
    "            \"document_text_with_emphasis\": \" \".join(\n",
    "                text.upper() if label == 1 else text\n",
    "                for label, text in zip(labels, sentences)\n",
    "            ),\n",
    "            # True when at least one sentence of the group is labelled 1.\n",
    "            \"relevant\": 1 in labels,\n",
    "        }\n",
    "    )\n",
    "df = pd.DataFrame(rows)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the records as JSON Lines: one JSON object per line, file named after the split.\n",
    "data_path = f\"wiki_qa-{split}.jsonl\"\n",
    "with open(data_path, \"w\") as f:\n",
    "    f.writelines(\n",
    "        json.dumps(record) + \"\\n\" for record in df.to_dict(orient=\"records\")\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
